{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Federated Learning - SMS spam prediction with a GRU model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "**NOTE**: At the time of running this notebook, we were running the grid components in background mode.\n",
    "\n",
    "**NOTE**: Components:\n",
    "\n",
    "* Grid Gateway (http://localhost:8080)\n",
    "* Grid Node Bob (http://localhost:3000)\n",
    "* Grid Node Alice (http://localhost:3001)\n",
    "\n",
    " \n",
    "To **start the gateway**:\n",
    "* ```cd gateway```\n",
    "* ```python gateway.py --start_local_db --port=5000```\n",
    "\n",
    "\n",
    "To **start one grid node**:\n",
    "\n",
    "* ```cd app/websocket/```\n",
    "\n",
    "* ```python websocket_app.py --start_local_db --id=alice --port=3001 --gateway_url=http://localhost:5000```\n",
    "\n",
    "This notebook was made based on [Federated SMS Spam prediction](https://github.com/OpenMined/PySyft/tree/master/examples/tutorials/advanced/Federated%20SMS%20Spam%20prediction).\n",
    "\n",
    "Authors:\n",
    "* André Macedo Farias: Github: [@andrelmfarias](https://github.com/andrelmfarias) | Twitter: [@andrelmfarias](https://twitter.com/andrelmfarias)\n",
    "* George Muraru: Github [@gmuraru](https://github/com/gmuraru) | Twitter: [@georgemuraru](https://twitter.com/georgemuraru) | Facebook: [@George Cristian Muraru](https://www.facebook.com/georgecmuraru)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Useful imports"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-14T22:55:56.381002Z",
     "start_time": "2019-06-14T22:55:52.562283Z"
    }
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "import torch\n",
    "\n",
    "import syft as sy\n",
    "from syft.workers.node_client import NodeClient"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>Setup config</h2>\n",
    "Init hook, connect with grid nodes, etc..."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "hook = sy.TorchHook(torch)\n",
    "\n",
    "# Connect directly to grid nodes\n",
    "nodes = [\"ws://localhost:3000/\",\n",
    "         \"ws://localhost:3001/\"]\n",
    "\n",
    "compute_nodes = []\n",
    "for node in nodes:\n",
    "    compute_nodes.append(NodeClient(hook, node))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Load Dataset"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 1) Download (if not present) and preprocess dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "scrolled": true
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Not downloading the dataset because it was already downloaded\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import urllib.request\n",
    "import pathlib\n",
    "from zipfile import ZipFile\n",
    "\n",
    "URL = \"https://archive.ics.uci.edu/ml/machine-learning-databases/00228/smsspamcollection.zip\"\n",
    "DATASET_NAME = \"smsspamcollection\"\n",
    "\n",
    "def dataset_exists():\n",
    "    return os.path.isfile('./data/inputs.npy') and \\\n",
    "    os.path.isfile('./data/labels.npy')\n",
    "    \n",
    "if not dataset_exists():\n",
    "    #If the dataset does not already exist, let's download the dataset directly from the URL where it is hosted\n",
    "    print('Downloading the dataset with urllib2 to the current directory...')\n",
    "    pathlib.Path(\"data\").mkdir(exist_ok=True)\n",
    "    urllib.request.urlretrieve(URL, './data/data.zip')\n",
    "    print(\"The dataset was successfully downloaded\")\n",
    "    print(\"Unzipping the dataset...\")\n",
    "    with ZipFile('./data/data.zip', 'r') as zipObj:\n",
    "       # Extract all the contents of the zip file in current directory\n",
    "       zipObj.extractall(\"./data\")\n",
    "    print(\"Dataset successfully unzipped\")\n",
    "    \n",
    "    from preprocess import preprocess_spam\n",
    "\n",
    "    preprocess_spam()\n",
    "else:\n",
    "    print(\"Not downloading the dataset because it was already downloaded\")\n",
    "    \n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 2) Loading data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As we are most interested in the usage of PySyft and Federated Learning, I will skip the text-preprocessing part of the project. If you are interested in how I performed the preprocessing of the raw dataset you can take a look on the script [preprocess.py](https://github.com/OpenMined/PyGrid/tree/master/examples/data/SMS-spam/preprocess.py).\n",
    "\n",
    "Each data point of the `inputs.npy` dataset correspond to an array of 30 tokens obtained form each message (padded at left or truncated at right)\n",
    "\n",
    "The `label.npy` dataset has the following unique values: `1` for `spam` and `0` for `non-spam`"
   ]
  },
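  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "To make the padding rule concrete, here is a minimal pure-Python sketch of left-padding and right-truncating a token list to a fixed length. The helper name `pad_or_truncate` and the padding id `0` are assumptions for illustration; they are not taken from `preprocess.py`, which may use different names or values.\n",
    "\n",
    "```python\n",
    "def pad_or_truncate(tokens, length=30, pad_id=0):\n",
    "    # Hypothetical helper: pad_id=0 is an assumed padding value.\n",
    "    kept = list(tokens)[:length]                # truncate on the right\n",
    "    padding = [pad_id] * (length - len(kept))   # missing positions go on the left\n",
    "    return padding + kept\n",
    "\n",
    "\n",
    "short = pad_or_truncate([5, 6, 7], length=5)\n",
    "long = pad_or_truncate(range(1, 41))\n",
    "print(short)      # [0, 0, 5, 6, 7]\n",
    "print(len(long))  # 30\n",
    "```\n",
    "\n",
    "Padding on the left keeps the real tokens at the end of the sequence, so a recurrent model such as a GRU reads them last."
   ]
  },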
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "ExecuteTime": {
     "end_time": "2019-06-14T22:59:06.345073Z",
     "start_time": "2019-06-14T22:59:06.322378Z"
    }
   },
   "outputs": [],
   "source": [
    "inputs = np.load('./data/inputs.npy')\n",
    "labels = np.load('./data/labels.npy')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "datasets_spam = torch.split(torch.tensor(inputs), int(len(inputs) / len(compute_nodes)), dim=0 ) #tuple of chunks (dataset / number of nodes)\n",
    "labels_spam = torch.split(torch.tensor(labels), int(len(labels) / len(compute_nodes)), dim=0 )  #tuple of chunks (labels / number of nodes)"
   ]
  },
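  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "One detail of the split above is worth spelling out: `torch.split` takes a chunk *size*, not a chunk count, so when the number of samples is not divisible by the number of nodes it returns one extra, smaller chunk. The sketch below mirrors that size arithmetic in plain Python. The helper `split_sizes` is hypothetical, the sample counts are illustrative, and it assumes there are at least as many samples as nodes.\n",
    "\n",
    "```python\n",
    "def split_sizes(n_samples, n_nodes):\n",
    "    # Hypothetical helper: chunk lengths produced by splitting with\n",
    "    # chunk size n_samples // n_nodes (assumes n_samples >= n_nodes).\n",
    "    chunk = n_samples // n_nodes\n",
    "    sizes = [chunk] * (n_samples // chunk)   # full-size chunks\n",
    "    remainder = n_samples % chunk\n",
    "    if remainder:\n",
    "        sizes.append(remainder)              # leftover samples form an extra chunk\n",
    "    return sizes\n",
    "\n",
    "\n",
    "even = split_sizes(5572, 2)\n",
    "uneven = split_sizes(5573, 2)\n",
    "print(even)    # [2786, 2786]\n",
    "print(uneven)  # [2786, 2786, 1]\n",
    "```\n",
    "\n",
    "Since later cells index the chunks with `range(len(compute_nodes))`, any extra chunk would simply not be sent to a node. If that matters, `torch.chunk` is one alternative, as it never returns more chunks than requested."
   ]
  },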
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2>3) Tagging tensors</h2>\n",
    "The code below will add a tag (of your choice) to the data that will be sent to grid nodes. This tag is important as the gateway will need it to retrieve this data later."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "tag_img = []\n",
    "tag_label = []\n",
    "\n",
    "\n",
    "for i in range(len(compute_nodes)):\n",
    "    tag_img.append(datasets_spam[i].tag(\"#X\", \"#spam\", \"#dataset\").describe(\"The input datapoints to the SPAM dataset.\"))\n",
    "    tag_label.append(labels_spam[i].tag(\"#Y\", \"#spam\", \"#dataset\").describe(\"The input labels to the SPAM dataset.\"))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "<h2> 4) Sending our tensors to grid nodes</h2>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "X tensor pointers:  (Wrapper)>[PointerTensor | me:4887477248 -> Bob:58176119719]\n",
      "\tTags: #dataset #X #spam \n",
      "\tShape: torch.Size([2786, 30])\n",
      "\tDescription: The input datapoints to the SPAM dataset.... (Wrapper)>[PointerTensor | me:7448093939 -> Bob:34347751563]\n",
      "\tTags: #dataset #spam #Y \n",
      "\tShape: torch.Size([2786])\n",
      "\tDescription: The input labels to the SPAM dataset....\n",
      "X tensor pointers:  (Wrapper)>[PointerTensor | me:15289568769 -> Alice:16627579692]\n",
      "\tTags: #dataset #X #spam \n",
      "\tShape: torch.Size([2786, 30])\n",
      "\tDescription: The input datapoints to the SPAM dataset.... (Wrapper)>[PointerTensor | me:42923736144 -> Alice:89030703951]\n",
      "\tTags: #dataset #spam #Y \n",
      "\tShape: torch.Size([2786])\n",
      "\tDescription: The input labels to the SPAM dataset....\n"
     ]
    }
   ],
   "source": [
    "# NOTE: For some reason, there is strange behavior when trying to send within a loop.\n",
    "# Ex : tag_x[i].send(compute_nodes[i])\n",
    "# When resolved, this should be updated.\n",
    "\n",
    "for i in range(len(compute_nodes)):\n",
    "    shared_x = tag_img[i].send(compute_nodes[i], garbage_collect_data=False)\n",
    "    shared_y = tag_label[i].send(compute_nodes[i], garbage_collect_data=False)\n",
    "    print(\"X tensor pointers: \", shared_x, shared_y)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Disconnect Nodes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i in range(len(compute_nodes)):\n",
    "    compute_nodes[i].close()"
   ]
  }
 ],
 "metadata": {
  "hide_input": false,
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.3"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": true,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  },
  "varInspector": {
   "cols": {
    "lenName": 16,
    "lenType": 16,
    "lenVar": 40
   },
   "kernels_config": {
    "python": {
     "delete_cmd_postfix": "",
     "delete_cmd_prefix": "del ",
     "library": "var_list.py",
     "varRefreshCmd": "print(var_dic_list())"
    },
    "r": {
     "delete_cmd_postfix": ") ",
     "delete_cmd_prefix": "rm(",
     "library": "var_list.r",
     "varRefreshCmd": "cat(var_dic_list()) "
    }
   },
   "types_to_exclude": [
    "module",
    "function",
    "builtin_function_or_method",
    "instance",
    "_Feature"
   ],
   "window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
